"""Plot experiment results."""

import json
import os

from absl import app
from absl import flags
import matplotlib.pyplot as plt
import numpy as np
from sklearn.utils.extmath import randomized_svd

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and
# 3.8 removed the old names, so prefer the new name when it is available and
# fall back to the legacy name on older Matplotlib releases.
plt.style.use(
    'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

# Command-line flags: where to read the JSON measurements from, and where to
# write the resulting figure (None means show it interactively instead).
flags.DEFINE_string(
    name='input_dir', default='experiments', help='Dir with measurements')
flags.DEFINE_string(name='output_file', default=None, help='File with plots')
FLAGS = flags.FLAGS

# Datasets whose results are loaded and plotted by main().
DS = [
    'mnist',
    'fashion_mnist',
    'smallnorb',
    'colorectal_histology',
]


def plot(ax, results, title, start_rank=0, algos=('em', 'svd_w+em')):
  """Draw the per-iteration loss curves of several algorithms on one axes.

  Args:
    ax: Matplotlib axes to draw on.
    results: Mapping from algorithm name to its sequence of loss values,
      one value per iteration.
    title: Prefix for the subplot title.
    start_rank: First iteration index to show; earlier points are skipped,
      which is useful for zooming in on the tail of the curves.
    algos: Keys of `results` to plot, in legend order.
  """
  end = 0
  for algo in algos:
    result = results[algo]
    iters = list(range(len(result)))
    ax.plot(iters[start_rank:], result[start_rank:], linewidth=3, label=algo)
    end = max(end, len(result))

  ax.set_ylabel('Loss', fontsize=18)
  ax.set_xlabel('Iterations', fontsize=18)
  ax.legend(fontsize=16)
  # Derive the upper bound from the data rather than hard-coding 50, so the
  # title stays accurate for runs of any length.
  ax.set_title(f'{title} iter [{start_rank}:{end}]', fontsize=20)


def main(argv) -> None:
  """Load the EM rank-20 results for every dataset in DS and plot them.

  Each dataset gets two adjacent panels: the full loss curves, and the same
  curves starting from iteration `zoom_start`. The figure is written to
  --output_file when that flag is set, otherwise it is shown interactively.

  Args:
    argv: Command-line arguments left over after flag parsing (unused).

  Raises:
    ValueError: If the results file for any dataset does not exist.
  """
  del argv  # Unused.
  zoom_start = 20  # First iteration shown in the zoomed-in panels.

  results = {}
  for ds in DS:
    input_file = os.path.join(FLAGS.input_dir, f'{ds}-em-rank-20.json')
    if not os.path.exists(input_file):
      raise ValueError(f'Path {input_file} does not exist')
    with open(input_file, 'r', encoding='utf-8') as fp:
      results[ds] = json.load(fp)

  plt.rc('xtick', labelsize=14)  # fontsize of the tick labels
  plt.rc('ytick', labelsize=14)  # fontsize of the tick labels
  fig = plt.figure(figsize=(20, 6))
  # Two datasets per row and two panels per dataset. squeeze=False keeps
  # `axs` two-dimensional even when a single row is enough.
  nrows = max(1, (len(DS) + 1) // 2)
  axs = fig.subplots(nrows=nrows, ncols=4, squeeze=False)

  for i, ds in enumerate(DS):
    row, pair = divmod(i, 2)
    plot(axs[row][2 * pair], results[ds], ds)
    plot(axs[row][2 * pair + 1], results[ds], ds, start_rank=zoom_start)

  fig.tight_layout(h_pad=3, w_pad=3)

  if FLAGS.output_file is not None:
    fig.savefig(FLAGS.output_file, bbox_inches='tight')
  else:
    plt.show()


if __name__ == '__main__':
  # Hand control to absl, which parses the flags before calling main(argv).
  app.run(main=main)
